import json
import random
import pathlib
import re
from typing import Union, List, Dict, Optional

# ------------------------- config -------------------------
# Number of prompts generated for each task.
NUM_PROMPTS: int = 64
# Base random seed; the task number is added to it so every task is seeded differently.
SEED: int = 42
# Directory holding the per-task shell configs, looked up as task<N>.sh.
# NOTE(review): the value looks like a placeholder — set a real path before running.
TASK_CONFIG_DIR: str = "<path_to_your_task_configs>"
# Number of task configs to process; tasks are numbered 1..NUM_TASKS inclusive.
NUM_TASKS: int = 15
# ------------------------- end config -------------------------

# Matches one shell assignment per line: VAR="value", VAR='value', VAR=value,
# optionally prefixed with `export`.  Keys are upper-case identifiers that may
# contain digits after the first character (e.g. LORA_RANK8).  Leading and
# trailing quote characters are left out of the captured value.
_SH_ASSIGNMENT_RE = re.compile(
    r'^(?:export\s+)?([A-Z_][A-Z0-9_]*)=["\']*(.+?)["\']*$'
)


def parse_sh_config(filepath: Union[str, pathlib.Path]) -> Dict[str, str]:
    """
    Parse a shell script configuration file and extract key-value pairs.

    Only simple one-line assignments are recognised; blank lines, comment
    lines (starting with '#') and every other kind of shell statement are
    ignored. When a key is assigned more than once, the last value wins.

    Args:
        filepath: Path of the shell config file to read (decoded as UTF-8).

    Returns:
        A dictionary mapping variable names to their unquoted string values.
        If the file does not exist, a warning is printed and an empty
        dictionary is returned.
    """
    config: Dict[str, str] = {}
    try:
        # Decode explicitly as UTF-8 so non-ASCII prompt text does not depend
        # on the platform's default encoding.
        with open(filepath, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                # Skip comments and empty lines
                if not line or line.startswith('#'):
                    continue

                match = _SH_ASSIGNMENT_RE.match(line)
                if match is None:
                    continue

                key, value = match.groups()
                # A value made only of quotes (e.g. VAR="") still carries one
                # quote character at this point; strip it.
                config[key] = value.strip('"').strip("'")
    except FileNotFoundError:
        print(f"Warning: Config file {filepath} not found")

    return config

def build_class_prompts_with_base(base_prompt: str, num: int = 64, seed: int = 0) -> List[str]:
    """
    Return *num* text prompts starting with the base_prompt and adding variations.
    For example:
        'a photo of a <sks> dog, on a snowy mountain at sunset, ...'

    Each prompt is the base prompt followed by two to four comma-separated
    descriptors drawn from four pools (scene, time of day, photographic style,
    action).

    Args:
        base_prompt: Text every generated prompt starts with, used verbatim.
        num: Number of prompts to generate.
        seed: Seed of the random generator; the same seed yields the same list.

    Returns:
        A list of *num* prompt strings (duplicates are possible).
    """
    # Use a private generator: results stay reproducible for a given seed
    # without re-seeding the module-level `random` state that other code in
    # the process may rely on.
    rng = random.Random(seed)

    scenes = [
        "on a sunny beach", "in a lush green park", "inside a cozy living-room",
        "on a snowy mountain at sunset", "running across a golden wheat field",
        "under cherry-blossom trees", "posing in front of the Eiffel Tower",
        "sitting on a skateboard at a skatepark", "playing fetch by a forest lake",
        "splashing in a backyard pool", "walking through bustling city streets",
        "lying on a vintage Persian rug", "climbing a rocky cliff by the sea",
        "resting beside a campfire", "sitting in a canoe on calm water",
        "wrapped in a warm blanket during a snowfall"
    ]

    times = [
        "at dawn", "at golden hour", "at blue hour", "on an overcast day",
        "under neon lights at night", "at high noon", "during a thunderstorm",
        "at twilight"
    ]

    styles = [
        "high-resolution photograph", "analog film style", "soft focus portrait",
        "35 mm film", "ultra-wide angle shot", "macro shot", "HDR", "vintage Polaroid",
        "cinematic still", "professional studio lighting", "low-key lighting",
        "minimalist composition", "aerial drone view"
    ]

    actions = [
        "wearing sunglasses", "jumping over an obstacle", "catching a frisbee",
        "tilting its head", "sleeping peacefully", "wagging its tail",
        "smiling at the camera", "with its tongue out", "howling playfully",
        "sniffing a flower", "splashing water", "wrapped in a scarf"
    ]

    # (pool, skip threshold): a pool contributes one descriptor when the draw
    # exceeds its threshold, i.e. with probability 80% / 70% / 80% / 60%.
    weighted_pools = (
        (scenes, 0.2),
        (times, 0.3),
        (styles, 0.2),
        (actions, 0.4),
    )
    all_pools = [scenes, times, styles, actions]

    def sample_prompt() -> str:
        # Start with the exact base prompt from the config
        elements: List[str] = []
        for pool, skip_threshold in weighted_pools:
            if rng.random() > skip_threshold:
                elements.append(rng.choice(pool))

        # Ensure we have at least 2 elements for variety; top up from a
        # randomly chosen pool, never repeating a descriptor already picked.
        while len(elements) < 2:
            element = rng.choice(rng.choice(all_pools))
            if element not in elements:
                elements.append(element)

        return ", ".join([base_prompt, *elements])

    return [sample_prompt() for _ in range(num)]


def save_prompts_json(prompts: List[str], out_path: Union[str, pathlib.Path]) -> None:
    """
    Write *prompts* to *out_path* as a COCO-Captions-style JSON document.

    The file holds a single object whose "annotations" list contains one
    {"caption": <prompt>} entry per prompt, in the given order. It is written
    as UTF-8 with a two-space indent and non-ASCII characters kept unescaped.
    An existing file at *out_path* is overwritten.
    """
    annotations = []
    for caption in prompts:
        annotations.append({"caption": caption})

    payload = {"annotations": annotations}
    serialized = json.dumps(payload, indent=2, ensure_ascii=False)

    target = pathlib.Path(out_path)
    with target.open("w", encoding="utf-8") as handle:
        handle.write(serialized)


def extract_concept_from_prompt(instance_prompt: str) -> str:
    """
    Extract a concept name from the instance prompt.
    E.g., "a photo of a <sks> dog" -> "dog"

    The word following the "<sks>" token is preferred and returned as written.
    If there is none, the prompt is searched case-insensitively for a short
    list of known concept words, checked in list order; the first one present
    as a whole word (optionally with a plural "s") is returned.

    Args:
        instance_prompt: Prompt text to inspect.

    Returns:
        The detected concept word, or "concept" when nothing is recognised.
    """
    # Try to find the word after <sks>
    match = re.search(r'<sks>\s+(\w+)', instance_prompt)
    if match:
        return match.group(1)

    # Fallback: try to find any common concept words.  Match whole words only,
    # so that e.g. "scarf" or "cartoon" is not mistaken for "car", nor
    # "catching" for "cat".
    lowered = instance_prompt.lower()
    concepts = ['dog', 'cat', 'car', 'person', 'object', 'style', 'building', 'flower']
    for concept in concepts:
        if re.search(rf'\b{concept}s?\b', lowered):
            return concept

    return "concept"  # Default fallback


# ------------------------- main execution -------------------------
def main() -> None:
    """
    Generate one calibration prompt file per task configuration.

    For each task number 1..NUM_TASKS, reads task<N>.sh from TASK_CONFIG_DIR,
    builds NUM_PROMPTS prompts from its INSTANCE_PROMPT (seeded with
    SEED + N) and writes them to calibration_task<N>.json in the current
    working directory. Tasks whose config yields no INSTANCE_PROMPT are
    skipped with a warning.
    """
    processed = 0

    # Process each task configuration
    for task_num in range(1, NUM_TASKS + 1):
        config_path = pathlib.Path(TASK_CONFIG_DIR) / f"task{task_num}.sh"

        print(f"\nProcessing {config_path}...")

        # Parse the configuration file
        config = parse_sh_config(config_path)

        if "INSTANCE_PROMPT" not in config:
            print(f"  Warning: No INSTANCE_PROMPT found in {config_path}, skipping...")
            continue

        instance_prompt = config["INSTANCE_PROMPT"]
        concept_name = config.get("CONCEPT_NAME", "")

        # The concept name is only used for the progress output below.
        if not concept_name:
            # If no concept name in config, try to extract from prompt
            concept_name = extract_concept_from_prompt(instance_prompt)
        else:
            # Keep the part after the last underscore (e.g., "1_dog" -> "dog");
            # a name without underscores is left unchanged.
            concept_name = concept_name.split('_')[-1]

        print(f"  Instance prompt: {instance_prompt}")
        print(f"  Concept: {concept_name}")

        # Generate prompts using the exact instance prompt as base; offsetting
        # the seed by the task number gives each task its own variations.
        prompts = build_class_prompts_with_base(instance_prompt, NUM_PROMPTS, SEED + task_num)

        # Save to JSON file
        out_file = f"calibration_task{task_num}.json"
        save_prompts_json(prompts, out_file)
        processed += 1

        print(f"  Wrote {len(prompts)} prompts to {out_file}")

        # Show a few examples
        print("  Sample prompts:")
        for index, prompt in enumerate(prompts[:3], start=1):
            print(f"    {index}. {prompt}")

    # Report how many configs actually produced output, not just how many
    # were attempted, so skipped tasks are visible in the summary.
    print(f"\nCompleted processing {processed} of {NUM_TASKS} task configurations!")


if __name__ == "__main__":
    main()